import os
import shutil
import tempfile
from collections import defaultdict

import pybedtools
from numpy import *

import warnings
# Escalate every warning to an exception so the script aborts instead of
# continuing with suspect values -- e.g. a numpy RuntimeWarning from a
# division by zero while normalizing the expression table.
warnings.filterwarnings('error')


def read_peaks():
    """Group the peaks of the HiSeq/StartSeq GFF file by annotated gene.

    Every record is named ``<chrom>_<start>-<end>_<strand>`` (coordinates as
    reported by pybedtools).  Records carrying a ``gene`` attribute are added
    under each gene listed, comma-separated, in that attribute; records
    without one are skipped.

    Returns
    -------
    dict
        Maps each gene name to the list of its peak names, in file order.
    """
    gff_path = "peaks.HiSeq_StartSeq.gff"
    print("Reading", gff_path)
    gene_to_peaks = defaultdict(list)
    for interval in pybedtools.BedTool(gff_path):
        peak_name = "%s_%d-%d_%s" % (
            interval.chrom, interval.start, interval.end, interval.strand)
        gene_field = interval.attrs.get("gene")
        if gene_field is not None:
            for gene in gene_field.split(","):
                gene_to_peaks[gene].append(peak_name)
    return dict(gene_to_peaks)

# Exact header expected in the expression table: the peak name followed by
# the HiSeq time-course libraries and the two StartSeq libraries.
_EXPRESSION_HEADER = [
    'peak',
    'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
    'HiSeq_t01_r1', 'HiSeq_t01_r2',
    'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
    'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
    'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
    'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
    'StartSeq_SRR7071452', 'StartSeq_SRR7071453',
]


def _hiseq_timepoint_columns(header_words):
    """Validate the header and return per-time-point HiSeq column indices.

    The returned list holds one integer index array per time point
    (0, 1, 4, 12, 24, 96 h).  Indices refer to the data columns, i.e. the
    header columns after the leading ``peak`` column.  StartSeq libraries are
    validated but not included.

    Raises
    ------
    AssertionError
        If the header does not match the expected layout.
    """
    assert header_words == _EXPRESSION_HEADER
    timepoints = (0, 1, 4, 12, 24, 96)
    indices = [[] for _ in timepoints]
    for index, word in enumerate(header_words[1:]):
        dataset, library = word.split("_", 1)
        if dataset == "HiSeq":
            timepoint, replicate = library.split("_")
            assert replicate in ("r1", "r2", "r3")
            assert timepoint.startswith("t")
            indices[timepoints.index(int(timepoint[1:]))].append(index)
        elif dataset == "StartSeq":
            assert library in ("SRR7071452", "SRR7071453")
        else:
            raise Exception("Unknown experiment %s" % dataset)
    return [array(columns) for columns in indices]


def read_expression_data(peaks, filename="peaks.HiSeq_StartSeq.expression.txt"):
    """Select the most highly expressed peak for every gene.

    Reads the whitespace-separated count table ``filename``.  Each library
    (column) is scaled to sum to one million.  The HiSeq replicates of each
    time point are averaged, and the six time-point means are averaged again
    into one score per peak.  StartSeq libraries do not contribute to the
    score.

    Parameters
    ----------
    peaks : dict
        Maps gene name to a list of peak names (as built by ``read_peaks``).
        Every listed peak must appear in the first column of the table.
    filename : str, optional
        Path of the expression table.

    Returns
    -------
    dict
        Maps each gene to the name of its highest-scoring peak.  Genes whose
        peaks all score zero (or NaN) are omitted.  Ties keep the first peak.

    Raises
    ------
    AssertionError
        If the header or a data row does not have the expected layout.
    KeyError
        If a peak listed in ``peaks`` is missing from the table.
    """
    print("Reading", filename)
    names = []
    data = []
    with open(filename) as stream:
        indices = _hiseq_timepoint_columns(next(stream).split())
        for line in stream:
            words = line.split()
            assert len(words) == 20
            names.append(words[0])
            data.append(array(words[1:20], float))
    data = array(data)
    # Scale every library (column) to one million total counts.
    data *= 1.e6 / data.sum(axis=0)
    scores = {}
    for name, row in zip(names, data):
        # Average replicates within a time point first so that each time
        # point carries equal weight regardless of its replicate count.
        scores[name] = array([row[columns].mean() for columns in indices]).mean()
    dominant_names = {}
    for gene, gene_peaks in peaks.items():
        # Strict '>' against a zero start: zero/NaN scores are never chosen
        # and the first of several equal maxima wins.
        maximum = 0
        dominant_name = None
        for name in gene_peaks:
            score = scores[name]
            if score > maximum:
                maximum = score
                dominant_name = name
        if dominant_name is not None:
            dominant_names[gene] = dominant_name
    return dominant_names

def write_peaks(names, filename="peaks.HiSeq_StartSeq.gff"):
    """Annotate the peaks GFF file in place with each gene's dominant peak.

    For every record with a ``gene`` attribute, the genes (from that
    comma-separated list) whose dominant peak is this record are stored,
    comma-separated, in a new ``dominant`` attribute.  All records are
    written to a temporary file in the same directory.  That file then
    replaces ``filename``, keeping the original file's permission bits.  On
    failure the temporary file is removed and the original is left untouched.

    Parameters
    ----------
    names : dict
        Maps gene name to dominant peak name (``<chrom>_<start>-<end>_<strand>``),
        as returned by ``read_expression_data``.
    filename : str, optional
        GFF file to rewrite.
    """
    print("Reading", filename)
    lines = pybedtools.BedTool(filename)
    # Create the temporary file next to the target so the final rename stays
    # on one filesystem and replaces the original atomically.
    directory = os.path.dirname(os.path.abspath(filename))
    stream = tempfile.NamedTemporaryFile(mode='wt', dir=directory, delete=False)
    print("Writing", stream.name)
    try:
        with stream:
            for line in lines:
                genes = line.attrs.get('gene')
                if genes is not None:
                    name = "%s_%d-%d_%s" % (line.chrom, line.start, line.end, line.strand)
                    dominant_genes = [gene for gene in genes.split(",")
                                      if names.get(gene) == name]
                    if dominant_genes:
                        line.attrs['dominant'] = ",".join(dominant_genes)
                stream.write(str(line))
        # NamedTemporaryFile creates the file readable/writable by the owner
        # only; carry the original file's mode over to the replacement.
        shutil.copymode(filename, stream.name)
        print("Renaming %s to %s" % (stream.name, filename))
        os.replace(stream.name, filename)
    except BaseException:
        # Do not leave a partial temporary file behind (delete=False above).
        os.unlink(stream.name)
        raise


# Pipeline: collect gene -> peak names from the GFF file, choose each gene's
# dominant peak from the expression table, then annotate the GFF in place.
# `peaks` and `names` remain available as module globals.
peaks = read_peaks()
names = read_expression_data(peaks=peaks)
write_peaks(names=names)
